{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "eCmjcOc9yEtQ"
   },
   "source": [
    "# Installing Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Tn1c-CoDv2kw"
   },
   "outputs": [],
   "source": [
    "!pip install \"faster-whisper>=1.1.0\"\n",
    "!pip install \"nemo-toolkit[asr]>=2.dev\"\n",
    "!pip install git+https://github.com/MahmoudAshraf97/demucs.git\n",
    "!pip install git+https://github.com/oliverguhr/deepmultilingualpunctuation.git\n",
    "!pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git\n",
    "!pip uninstall -y nvidia-cudnn-cu12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YzhncHP0ytbQ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import wget\n",
    "from omegaconf import OmegaConf\n",
    "import json\n",
    "import shutil\n",
    "import torch\n",
    "import torchaudio\n",
    "from nemo.collections.asr.models.msdd_models import NeuralDiarizer\n",
    "from deepmultilingualpunctuation import PunctuationModel\n",
    "import re\n",
    "import logging\n",
    "import nltk\n",
    "import faster_whisper\n",
    "from ctc_forced_aligner import (\n",
    "    load_alignment_model,\n",
    "    generate_emissions,\n",
    "    preprocess_text,\n",
    "    get_alignments,\n",
    "    get_spans,\n",
    "    postprocess_results,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "jbsUt3SwyhjD"
   },
   "source": [
    "# Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Se6Hc7CZygxu"
   },
   "outputs": [],
   "source": [
    "punct_model_langs = [\n",
    "    \"en\",\n",
    "    \"fr\",\n",
    "    \"de\",\n",
    "    \"es\",\n",
    "    \"it\",\n",
    "    \"nl\",\n",
    "    \"pt\",\n",
    "    \"bg\",\n",
    "    \"pl\",\n",
    "    \"cs\",\n",
    "    \"sk\",\n",
    "    \"sl\",\n",
    "]\n",
    "\n",
    "LANGUAGES = {\n",
    "    \"en\": \"english\",\n",
    "    \"zh\": \"chinese\",\n",
    "    \"de\": \"german\",\n",
    "    \"es\": \"spanish\",\n",
    "    \"ru\": \"russian\",\n",
    "    \"ko\": \"korean\",\n",
    "    \"fr\": \"french\",\n",
    "    \"ja\": \"japanese\",\n",
    "    \"pt\": \"portuguese\",\n",
    "    \"tr\": \"turkish\",\n",
    "    \"pl\": \"polish\",\n",
    "    \"ca\": \"catalan\",\n",
    "    \"nl\": \"dutch\",\n",
    "    \"ar\": \"arabic\",\n",
    "    \"sv\": \"swedish\",\n",
    "    \"it\": \"italian\",\n",
    "    \"id\": \"indonesian\",\n",
    "    \"hi\": \"hindi\",\n",
    "    \"fi\": \"finnish\",\n",
    "    \"vi\": \"vietnamese\",\n",
    "    \"he\": \"hebrew\",\n",
    "    \"uk\": \"ukrainian\",\n",
    "    \"el\": \"greek\",\n",
    "    \"ms\": \"malay\",\n",
    "    \"cs\": \"czech\",\n",
    "    \"ro\": \"romanian\",\n",
    "    \"da\": \"danish\",\n",
    "    \"hu\": \"hungarian\",\n",
    "    \"ta\": \"tamil\",\n",
    "    \"no\": \"norwegian\",\n",
    "    \"th\": \"thai\",\n",
    "    \"ur\": \"urdu\",\n",
    "    \"hr\": \"croatian\",\n",
    "    \"bg\": \"bulgarian\",\n",
    "    \"lt\": \"lithuanian\",\n",
    "    \"la\": \"latin\",\n",
    "    \"mi\": \"maori\",\n",
    "    \"ml\": \"malayalam\",\n",
    "    \"cy\": \"welsh\",\n",
    "    \"sk\": \"slovak\",\n",
    "    \"te\": \"telugu\",\n",
    "    \"fa\": \"persian\",\n",
    "    \"lv\": \"latvian\",\n",
    "    \"bn\": \"bengali\",\n",
    "    \"sr\": \"serbian\",\n",
    "    \"az\": \"azerbaijani\",\n",
    "    \"sl\": \"slovenian\",\n",
    "    \"kn\": \"kannada\",\n",
    "    \"et\": \"estonian\",\n",
    "    \"mk\": \"macedonian\",\n",
    "    \"br\": \"breton\",\n",
    "    \"eu\": \"basque\",\n",
    "    \"is\": \"icelandic\",\n",
    "    \"hy\": \"armenian\",\n",
    "    \"ne\": \"nepali\",\n",
    "    \"mn\": \"mongolian\",\n",
    "    \"bs\": \"bosnian\",\n",
    "    \"kk\": \"kazakh\",\n",
    "    \"sq\": \"albanian\",\n",
    "    \"sw\": \"swahili\",\n",
    "    \"gl\": \"galician\",\n",
    "    \"mr\": \"marathi\",\n",
    "    \"pa\": \"punjabi\",\n",
    "    \"si\": \"sinhala\",\n",
    "    \"km\": \"khmer\",\n",
    "    \"sn\": \"shona\",\n",
    "    \"yo\": \"yoruba\",\n",
    "    \"so\": \"somali\",\n",
    "    \"af\": \"afrikaans\",\n",
    "    \"oc\": \"occitan\",\n",
    "    \"ka\": \"georgian\",\n",
    "    \"be\": \"belarusian\",\n",
    "    \"tg\": \"tajik\",\n",
    "    \"sd\": \"sindhi\",\n",
    "    \"gu\": \"gujarati\",\n",
    "    \"am\": \"amharic\",\n",
    "    \"yi\": \"yiddish\",\n",
    "    \"lo\": \"lao\",\n",
    "    \"uz\": \"uzbek\",\n",
    "    \"fo\": \"faroese\",\n",
    "    \"ht\": \"haitian creole\",\n",
    "    \"ps\": \"pashto\",\n",
    "    \"tk\": \"turkmen\",\n",
    "    \"nn\": \"nynorsk\",\n",
    "    \"mt\": \"maltese\",\n",
    "    \"sa\": \"sanskrit\",\n",
    "    \"lb\": \"luxembourgish\",\n",
    "    \"my\": \"myanmar\",\n",
    "    \"bo\": \"tibetan\",\n",
    "    \"tl\": \"tagalog\",\n",
    "    \"mg\": \"malagasy\",\n",
    "    \"as\": \"assamese\",\n",
    "    \"tt\": \"tatar\",\n",
    "    \"haw\": \"hawaiian\",\n",
    "    \"ln\": \"lingala\",\n",
    "    \"ha\": \"hausa\",\n",
    "    \"ba\": \"bashkir\",\n",
    "    \"jw\": \"javanese\",\n",
    "    \"su\": \"sundanese\",\n",
    "    \"yue\": \"cantonese\",\n",
    "}\n",
    "\n",
    "# language code lookup by name, with a few language aliases\n",
    "TO_LANGUAGE_CODE = {\n",
    "    **{language: code for code, language in LANGUAGES.items()},\n",
    "    \"burmese\": \"my\",\n",
    "    \"valencian\": \"ca\",\n",
    "    \"flemish\": \"nl\",\n",
    "    \"haitian\": \"ht\",\n",
    "    \"letzeburgesch\": \"lb\",\n",
    "    \"pushto\": \"ps\",\n",
    "    \"panjabi\": \"pa\",\n",
    "    \"moldavian\": \"ro\",\n",
    "    \"moldovan\": \"ro\",\n",
    "    \"sinhalese\": \"si\",\n",
    "    \"castilian\": \"es\",\n",
    "}\n",
    "\n",
    "\n",
    "langs_to_iso = {\n",
    "    \"af\": \"afr\",\n",
    "    \"am\": \"amh\",\n",
    "    \"ar\": \"ara\",\n",
    "    \"as\": \"asm\",\n",
    "    \"az\": \"aze\",\n",
    "    \"ba\": \"bak\",\n",
    "    \"be\": \"bel\",\n",
    "    \"bg\": \"bul\",\n",
    "    \"bn\": \"ben\",\n",
    "    \"bo\": \"tib\",\n",
    "    \"br\": \"bre\",\n",
    "    \"bs\": \"bos\",\n",
    "    \"ca\": \"cat\",\n",
    "    \"cs\": \"cze\",\n",
    "    \"cy\": \"wel\",\n",
    "    \"da\": \"dan\",\n",
    "    \"de\": \"ger\",\n",
    "    \"el\": \"gre\",\n",
    "    \"en\": \"eng\",\n",
    "    \"es\": \"spa\",\n",
    "    \"et\": \"est\",\n",
    "    \"eu\": \"baq\",\n",
    "    \"fa\": \"per\",\n",
    "    \"fi\": \"fin\",\n",
    "    \"fo\": \"fao\",\n",
    "    \"fr\": \"fre\",\n",
    "    \"gl\": \"glg\",\n",
    "    \"gu\": \"guj\",\n",
    "    \"ha\": \"hau\",\n",
    "    \"haw\": \"haw\",\n",
    "    \"he\": \"heb\",\n",
    "    \"hi\": \"hin\",\n",
    "    \"hr\": \"hrv\",\n",
    "    \"ht\": \"hat\",\n",
    "    \"hu\": \"hun\",\n",
    "    \"hy\": \"arm\",\n",
    "    \"id\": \"ind\",\n",
    "    \"is\": \"ice\",\n",
    "    \"it\": \"ita\",\n",
    "    \"ja\": \"jpn\",\n",
    "    \"jw\": \"jav\",\n",
    "    \"ka\": \"geo\",\n",
    "    \"kk\": \"kaz\",\n",
    "    \"km\": \"khm\",\n",
    "    \"kn\": \"kan\",\n",
    "    \"ko\": \"kor\",\n",
    "    \"la\": \"lat\",\n",
    "    \"lb\": \"ltz\",\n",
    "    \"ln\": \"lin\",\n",
    "    \"lo\": \"lao\",\n",
    "    \"lt\": \"lit\",\n",
    "    \"lv\": \"lav\",\n",
    "    \"mg\": \"mlg\",\n",
    "    \"mi\": \"mao\",\n",
    "    \"mk\": \"mac\",\n",
    "    \"ml\": \"mal\",\n",
    "    \"mn\": \"mon\",\n",
    "    \"mr\": \"mar\",\n",
    "    \"ms\": \"may\",\n",
    "    \"mt\": \"mlt\",\n",
    "    \"my\": \"bur\",\n",
    "    \"ne\": \"nep\",\n",
    "    \"nl\": \"dut\",\n",
    "    \"nn\": \"nno\",\n",
    "    \"no\": \"nor\",\n",
    "    \"oc\": \"oci\",\n",
    "    \"pa\": \"pan\",\n",
    "    \"pl\": \"pol\",\n",
    "    \"ps\": \"pus\",\n",
    "    \"pt\": \"por\",\n",
    "    \"ro\": \"rum\",\n",
    "    \"ru\": \"rus\",\n",
    "    \"sa\": \"san\",\n",
    "    \"sd\": \"snd\",\n",
    "    \"si\": \"sin\",\n",
    "    \"sk\": \"slo\",\n",
    "    \"sl\": \"slv\",\n",
    "    \"sn\": \"sna\",\n",
    "    \"so\": \"som\",\n",
    "    \"sq\": \"alb\",\n",
    "    \"sr\": \"srp\",\n",
    "    \"su\": \"sun\",\n",
    "    \"sv\": \"swe\",\n",
    "    \"sw\": \"swa\",\n",
    "    \"ta\": \"tam\",\n",
    "    \"te\": \"tel\",\n",
    "    \"tg\": \"tgk\",\n",
    "    \"th\": \"tha\",\n",
    "    \"tk\": \"tuk\",\n",
    "    \"tl\": \"tgl\",\n",
    "    \"tr\": \"tur\",\n",
    "    \"tt\": \"tat\",\n",
    "    \"uk\": \"ukr\",\n",
    "    \"ur\": \"urd\",\n",
    "    \"uz\": \"uzb\",\n",
    "    \"vi\": \"vie\",\n",
    "    \"yi\": \"yid\",\n",
    "    \"yo\": \"yor\",\n",
    "    \"yue\": \"yue\",\n",
    "    \"zh\": \"chi\",\n",
    "}\n",
    "\n",
    "\n",
    "whisper_langs = sorted(LANGUAGES.keys()) + sorted(\n",
    "    [k.title() for k in TO_LANGUAGE_CODE.keys()]\n",
    ")\n",
    "\n",
    "\n",
    "def create_config(output_dir):\n",
    "    DOMAIN_TYPE = \"telephonic\"  # Can be meeting, telephonic, or general based on domain type of the audio file\n",
    "    CONFIG_FILE_NAME = f\"diar_infer_{DOMAIN_TYPE}.yaml\"\n",
    "    CONFIG_URL = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}\"\n",
    "    MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)\n",
    "    if not os.path.exists(MODEL_CONFIG):\n",
    "        MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)\n",
    "\n",
    "    config = OmegaConf.load(MODEL_CONFIG)\n",
    "\n",
    "    data_dir = os.path.join(output_dir, \"data\")\n",
    "    os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "    meta = {\n",
    "        \"audio_filepath\": os.path.join(output_dir, \"mono_file.wav\"),\n",
    "        \"offset\": 0,\n",
    "        \"duration\": None,\n",
    "        \"label\": \"infer\",\n",
    "        \"text\": \"-\",\n",
    "        \"rttm_filepath\": None,\n",
    "        \"uem_filepath\": None,\n",
    "    }\n",
    "    with open(os.path.join(data_dir, \"input_manifest.json\"), \"w\") as fp:\n",
    "        json.dump(meta, fp)\n",
    "        fp.write(\"\\n\")\n",
    "\n",
    "    pretrained_vad = \"vad_multilingual_marblenet\"\n",
    "    pretrained_speaker_model = \"titanet_large\"\n",
    "    config.num_workers = 0  # Workaround for multiprocessing hanging with ipython issue\n",
    "    config.diarizer.manifest_filepath = os.path.join(data_dir, \"input_manifest.json\")\n",
    "    config.diarizer.out_dir = (\n",
    "        output_dir  # Directory to store intermediate files and prediction outputs\n",
    "    )\n",
    "\n",
    "    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
    "    config.diarizer.oracle_vad = (\n",
    "        False  # compute VAD provided with model_path to vad config\n",
    "    )\n",
    "    config.diarizer.clustering.parameters.oracle_num_speakers = False\n",
    "\n",
    "    # Here, we use our in-house pretrained NeMo VAD model\n",
    "    config.diarizer.vad.model_path = pretrained_vad\n",
    "    config.diarizer.vad.parameters.onset = 0.8\n",
    "    config.diarizer.vad.parameters.offset = 0.6\n",
    "    config.diarizer.vad.parameters.pad_offset = -0.05\n",
    "    config.diarizer.msdd_model.model_path = (\n",
    "        \"diar_msdd_telephonic\"  # Telephonic speaker diarization model\n",
    "    )\n",
    "\n",
    "    return config\n",
    "\n",
    "\n",
    "def get_word_ts_anchor(s, e, option=\"start\"):\n",
    "    if option == \"end\":\n",
    "        return e\n",
    "    elif option == \"mid\":\n",
    "        return (s + e) / 2\n",
    "    return s\n",
    "\n",
    "\n",
    "def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option=\"start\"):\n",
    "    s, e, sp = spk_ts[0]\n",
    "    wrd_pos, turn_idx = 0, 0\n",
    "    wrd_spk_mapping = []\n",
    "    for wrd_dict in wrd_ts:\n",
    "        ws, we, wrd = (\n",
    "            int(wrd_dict[\"start\"] * 1000),\n",
    "            int(wrd_dict[\"end\"] * 1000),\n",
    "            wrd_dict[\"text\"],\n",
    "        )\n",
    "        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)\n",
    "        while wrd_pos > float(e):\n",
    "            turn_idx += 1\n",
    "            turn_idx = min(turn_idx, len(spk_ts) - 1)\n",
    "            s, e, sp = spk_ts[turn_idx]\n",
    "            if turn_idx == len(spk_ts) - 1:\n",
    "                e = get_word_ts_anchor(ws, we, option=\"end\")\n",
    "        wrd_spk_mapping.append(\n",
    "            {\"word\": wrd, \"start_time\": ws, \"end_time\": we, \"speaker\": sp}\n",
    "        )\n",
    "    return wrd_spk_mapping\n",
    "\n",
    "\n",
    "sentence_ending_punctuations = \".?!\"\n",
    "\n",
    "\n",
    "def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):\n",
    "    is_word_sentence_end = (\n",
    "        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
    "    )\n",
    "    left_idx = word_idx\n",
    "    while (\n",
    "        left_idx > 0\n",
    "        and word_idx - left_idx < max_words\n",
    "        and speaker_list[left_idx - 1] == speaker_list[left_idx]\n",
    "        and not is_word_sentence_end(left_idx - 1)\n",
    "    ):\n",
    "        left_idx -= 1\n",
    "\n",
    "    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1\n",
    "\n",
    "\n",
    "def get_last_word_idx_of_sentence(word_idx, word_list, max_words):\n",
    "    is_word_sentence_end = (\n",
    "        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
    "    )\n",
    "    right_idx = word_idx\n",
    "    while (\n",
    "        right_idx < len(word_list) - 1\n",
    "        and right_idx - word_idx < max_words\n",
    "        and not is_word_sentence_end(right_idx)\n",
    "    ):\n",
    "        right_idx += 1\n",
    "\n",
    "    return (\n",
    "        right_idx\n",
    "        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)\n",
    "        else -1\n",
    "    )\n",
    "\n",
    "\n",
    "def get_realigned_ws_mapping_with_punctuation(\n",
    "    word_speaker_mapping, max_words_in_sentence=50\n",
    "):\n",
    "    is_word_sentence_end = (\n",
    "        lambda x: x >= 0\n",
    "        and word_speaker_mapping[x][\"word\"][-1] in sentence_ending_punctuations\n",
    "    )\n",
    "    wsp_len = len(word_speaker_mapping)\n",
    "\n",
    "    words_list, speaker_list = [], []\n",
    "    for k, line_dict in enumerate(word_speaker_mapping):\n",
    "        word, speaker = line_dict[\"word\"], line_dict[\"speaker\"]\n",
    "        words_list.append(word)\n",
    "        speaker_list.append(speaker)\n",
    "\n",
    "    k = 0\n",
    "    while k < len(word_speaker_mapping):\n",
    "        line_dict = word_speaker_mapping[k]\n",
    "        if (\n",
    "            k < wsp_len - 1\n",
    "            and speaker_list[k] != speaker_list[k + 1]\n",
    "            and not is_word_sentence_end(k)\n",
    "        ):\n",
    "            left_idx = get_first_word_idx_of_sentence(\n",
    "                k, words_list, speaker_list, max_words_in_sentence\n",
    "            )\n",
    "            right_idx = (\n",
    "                get_last_word_idx_of_sentence(\n",
    "                    k, words_list, max_words_in_sentence - k + left_idx - 1\n",
    "                )\n",
    "                if left_idx > -1\n",
    "                else -1\n",
    "            )\n",
    "            if min(left_idx, right_idx) == -1:\n",
    "                k += 1\n",
    "                continue\n",
    "\n",
    "            spk_labels = speaker_list[left_idx : right_idx + 1]\n",
    "            mod_speaker = max(set(spk_labels), key=spk_labels.count)\n",
    "            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:\n",
    "                k += 1\n",
    "                continue\n",
    "\n",
    "            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (\n",
    "                right_idx - left_idx + 1\n",
    "            )\n",
    "            k = right_idx\n",
    "\n",
    "        k += 1\n",
    "\n",
    "    k, realigned_list = 0, []\n",
    "    while k < len(word_speaker_mapping):\n",
    "        line_dict = word_speaker_mapping[k].copy()\n",
    "        line_dict[\"speaker\"] = speaker_list[k]\n",
    "        realigned_list.append(line_dict)\n",
    "        k += 1\n",
    "\n",
    "    return realigned_list\n",
    "\n",
    "\n",
    "def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):\n",
    "    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak\n",
    "    s, e, spk = spk_ts[0]\n",
    "    prev_spk = spk\n",
    "\n",
    "    snts = []\n",
    "    snt = {\"speaker\": f\"Speaker {spk}\", \"start_time\": s, \"end_time\": e, \"text\": \"\"}\n",
    "\n",
    "    for wrd_dict in word_speaker_mapping:\n",
    "        wrd, spk = wrd_dict[\"word\"], wrd_dict[\"speaker\"]\n",
    "        s, e = wrd_dict[\"start_time\"], wrd_dict[\"end_time\"]\n",
    "        if spk != prev_spk or sentence_checker(snt[\"text\"] + \" \" + wrd):\n",
    "            snts.append(snt)\n",
    "            snt = {\n",
    "                \"speaker\": f\"Speaker {spk}\",\n",
    "                \"start_time\": s,\n",
    "                \"end_time\": e,\n",
    "                \"text\": \"\",\n",
    "            }\n",
    "        else:\n",
    "            snt[\"end_time\"] = e\n",
    "        snt[\"text\"] += wrd + \" \"\n",
    "        prev_spk = spk\n",
    "\n",
    "    snts.append(snt)\n",
    "    return snts\n",
    "\n",
    "\n",
    "def get_speaker_aware_transcript(sentences_speaker_mapping, f):\n",
    "    previous_speaker = sentences_speaker_mapping[0][\"speaker\"]\n",
    "    f.write(f\"{previous_speaker}: \")\n",
    "\n",
    "    for sentence_dict in sentences_speaker_mapping:\n",
    "        speaker = sentence_dict[\"speaker\"]\n",
    "        sentence = sentence_dict[\"text\"]\n",
    "\n",
    "        # If this speaker doesn't match the previous one, start a new paragraph\n",
    "        if speaker != previous_speaker:\n",
    "            f.write(f\"\\n\\n{speaker}: \")\n",
    "            previous_speaker = speaker\n",
    "\n",
    "        # No matter what, write the current sentence\n",
    "        f.write(sentence + \" \")\n",
    "\n",
    "\n",
    "def format_timestamp(\n",
    "    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = \".\"\n",
    "):\n",
    "    assert milliseconds >= 0, \"non-negative timestamp expected\"\n",
    "\n",
    "    hours = milliseconds // 3_600_000\n",
    "    milliseconds -= hours * 3_600_000\n",
    "\n",
    "    minutes = milliseconds // 60_000\n",
    "    milliseconds -= minutes * 60_000\n",
    "\n",
    "    seconds = milliseconds // 1_000\n",
    "    milliseconds -= seconds * 1_000\n",
    "\n",
    "    hours_marker = f\"{hours:02d}:\" if always_include_hours or hours > 0 else \"\"\n",
    "    return (\n",
    "        f\"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}\"\n",
    "    )\n",
    "\n",
    "\n",
    "def write_srt(transcript, file):\n",
    "    \"\"\"\n",
    "    Write a transcript to a file in SRT format.\n",
    "\n",
    "    \"\"\"\n",
    "    for i, segment in enumerate(transcript, start=1):\n",
    "        # write srt lines\n",
    "        print(\n",
    "            f\"{i}\\n\"\n",
    "            f\"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> \"\n",
    "            f\"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\\n\"\n",
    "            f\"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\\n\",\n",
    "            file=file,\n",
    "            flush=True,\n",
    "        )\n",
    "\n",
    "\n",
    "def find_numeral_symbol_tokens(tokenizer):\n",
    "    numeral_symbol_tokens = [\n",
    "        -1,\n",
    "    ]\n",
    "    for token, token_id in tokenizer.get_vocab().items():\n",
    "        has_numeral_symbol = any(c in \"0123456789%$£\" for c in token)\n",
    "        if has_numeral_symbol:\n",
    "            numeral_symbol_tokens.append(token_id)\n",
    "    return numeral_symbol_tokens\n",
    "\n",
    "\n",
    "def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):\n",
    "    # if current word is the last word\n",
    "    if current_word_index == len(word_timestamps) - 1:\n",
    "        return word_timestamps[current_word_index][\"start\"]\n",
    "\n",
    "    next_word_index = current_word_index + 1\n",
    "    while current_word_index < len(word_timestamps) - 1:\n",
    "        if word_timestamps[next_word_index].get(\"start\") is None:\n",
    "            # if next word doesn't have a start timestamp\n",
    "            # merge it with the current word and delete it\n",
    "            word_timestamps[current_word_index][\"word\"] += (\n",
    "                \" \" + word_timestamps[next_word_index][\"word\"]\n",
    "            )\n",
    "\n",
    "            word_timestamps[next_word_index][\"word\"] = None\n",
    "            next_word_index += 1\n",
    "            if next_word_index == len(word_timestamps):\n",
    "                return final_timestamp\n",
    "\n",
    "        else:\n",
    "            return word_timestamps[next_word_index][\"start\"]\n",
    "\n",
    "\n",
    "def filter_missing_timestamps(\n",
    "    word_timestamps, initial_timestamp=0, final_timestamp=None\n",
    "):\n",
    "    # handle the first and last word\n",
    "    if word_timestamps[0].get(\"start\") is None:\n",
    "        word_timestamps[0][\"start\"] = (\n",
    "            initial_timestamp if initial_timestamp is not None else 0\n",
    "        )\n",
    "        word_timestamps[0][\"end\"] = _get_next_start_timestamp(\n",
    "            word_timestamps, 0, final_timestamp\n",
    "        )\n",
    "\n",
    "    result = [\n",
    "        word_timestamps[0],\n",
    "    ]\n",
    "\n",
    "    for i, ws in enumerate(word_timestamps[1:], start=1):\n",
    "        # if ws doesn't have a start and end\n",
    "        # use the previous end as start and next start as end\n",
    "        if ws.get(\"start\") is None and ws.get(\"word\") is not None:\n",
    "            ws[\"start\"] = word_timestamps[i - 1][\"end\"]\n",
    "            ws[\"end\"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)\n",
    "\n",
    "        if ws[\"word\"] is not None:\n",
    "            result.append(ws)\n",
    "    return result\n",
    "\n",
    "\n",
    "def cleanup(path: str):\n",
    "    \"\"\"path could either be relative or absolute.\"\"\"\n",
    "    # check if file or directory exists\n",
    "    if os.path.isfile(path) or os.path.islink(path):\n",
    "        # remove file\n",
    "        os.remove(path)\n",
    "    elif os.path.isdir(path):\n",
    "        # remove directory and all its content\n",
    "        shutil.rmtree(path)\n",
    "    else:\n",
    "        raise ValueError(\"Path {} is not a file or dir.\".format(path))\n",
    "\n",
    "\n",
    "def process_language_arg(language: str, model_name: str):\n",
    "    \"\"\"\n",
    "    Process the language argument to make sure it's valid and convert language names to language codes.\n",
    "    \"\"\"\n",
    "    if language is not None:\n",
    "        language = language.lower()\n",
    "    if language not in LANGUAGES:\n",
    "        if language in TO_LANGUAGE_CODE:\n",
    "            language = TO_LANGUAGE_CODE[language]\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported language: {language}\")\n",
    "\n",
    "    if model_name.endswith(\".en\") and language != \"en\":\n",
    "        if language is not None:\n",
    "            logging.warning(\n",
    "                f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
    "            )\n",
    "        language = \"en\"\n",
    "    return language"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "B7qWQb--1Xcw"
   },
   "source": [
    "# Options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ONlFrSnD0FOp"
   },
   "outputs": [],
   "source": [
    "# Name of the audio file\n",
    "audio_path = \"20200128-Pieter Wuille (part 1 of 2) - Episode 1.mp3\"\n",
    "\n",
    "# Whether to enable music removal from speech, helps increase diarization quality but uses alot of ram\n",
    "enable_stemming = True\n",
    "\n",
    "# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large', 'turbo')\n",
    "whisper_model_name = \"large-v2\"\n",
    "\n",
    "# replaces numerical digits with their pronounciation, increases diarization accuracy\n",
    "suppress_numerals = True\n",
    "\n",
    "batch_size = 8\n",
    "\n",
    "language = None  # autodetect language\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "h-cY1ZEy2KVI"
   },
   "source": [
    "# Processing"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "7ZS4xXmE2NGP"
   },
   "source": [
    "## Separating music from speech using Demucs\n",
    "\n",
    "---\n",
    "\n",
    "By isolating the vocals from the rest of the audio, it becomes easier to identify and track individual speakers based on the spectral and temporal characteristics of their speech signals. Source separation is just one of many techniques that can be used as a preprocessing step to help improve the accuracy and reliability of the overall diarization process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "HKcgQUrAzsJZ",
    "outputId": "dc2a1d96-20da-4749-9d64-21edacfba1b1"
   },
   "outputs": [],
   "source": [
    "if enable_stemming:\n",
    "    # Isolate vocals from the rest of the audio\n",
    "\n",
    "    return_code = os.system(\n",
    "        f'python -m demucs.separate -n htdemucs --two-stems=vocals \"{audio_path}\" -o \"temp_outputs\" --device \"{device}\"'\n",
    "    )\n",
    "\n",
    "    if return_code != 0:\n",
    "        logging.warning(\"Source splitting failed, using original audio file.\")\n",
    "        vocal_target = audio_path\n",
    "    else:\n",
    "        vocal_target = os.path.join(\n",
    "            \"temp_outputs\",\n",
    "            \"htdemucs\",\n",
    "            os.path.splitext(os.path.basename(audio_path))[0],\n",
    "            \"vocals.wav\",\n",
    "        )\n",
    "else:\n",
    "    vocal_target = audio_path"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "UYg9VWb22Tz8"
   },
   "source": [
    "## Transcriping audio using Whisper and realligning timestamps using Forced Alignment\n",
    "---\n",
    "This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.\n",
    "\n",
    "The first model is called OpenAI Whisper, which is a speech recognition model that can transcribe speech with high accuracy. The code loads the whisper model and uses it to transcribe the vocal_target file.\n",
    "\n",
    "The output of the transcription process is a set of text segments with corresponding timestamps indicating when each segment was spoken.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5-VKFn530oTl"
   },
   "outputs": [],
   "source": [
    "compute_type = \"float16\"\n",
    "# or run on GPU with INT8\n",
    "# compute_type = \"int8_float16\"\n",
    "# or run on CPU with INT8\n",
    "# compute_type = \"int8\"\n",
    "\n",
    "whisper_model = faster_whisper.WhisperModel(\n",
    "    whisper_model_name, device=device, compute_type=compute_type\n",
    ")\n",
    "whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)\n",
    "audio_waveform = faster_whisper.decode_audio(vocal_target)\n",
    "suppress_tokens = (\n",
    "    find_numeral_symbol_tokens(whisper_model.hf_tokenizer)\n",
    "    if suppress_numerals\n",
    "    else [-1]\n",
    ")\n",
    "\n",
    "if batch_size > 0:\n",
    "    transcript_segments, info = whisper_pipeline.transcribe(\n",
    "        audio_waveform,\n",
    "        language,\n",
    "        suppress_tokens=suppress_tokens,\n",
    "        batch_size=batch_size,\n",
    "        without_timestamps=True,\n",
    "    )\n",
    "else:\n",
    "    transcript_segments, info = whisper_model.transcribe(\n",
    "        audio_waveform,\n",
    "        language,\n",
    "        suppress_tokens=suppress_tokens,\n",
    "        without_timestamps=True,\n",
    "        vad_filter=True,\n",
    "    )\n",
    "\n",
    "full_transcript = \"\".join(segment.text for segment in transcript_segments)\n",
    "\n",
    "# clear gpu vram\n",
    "del whisper_model, whisper_pipeline\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aligning the transcription with the original audio using Forced Alignment\n",
    "---\n",
    "Forced alignment aims to to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n",
    "\n",
    "By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the forced-alignment model; half precision is only requested on CUDA.\n",
    "alignment_model, alignment_tokenizer = load_alignment_model(\n",
    "    device,\n",
    "    dtype=torch.float16 if device == \"cuda\" else torch.float32,\n",
    ")\n",
    "\n",
    "# torch.as_tensor accepts both a NumPy array (fresh run) and the tensor this\n",
    "# cell leaves behind in `audio_waveform`, so the cell can be re-executed\n",
    "# without torch.from_numpy raising a TypeError on an already-converted input.\n",
    "audio_waveform = (\n",
    "    torch.as_tensor(audio_waveform)\n",
    "    .to(alignment_model.dtype)\n",
    "    .to(alignment_model.device)\n",
    ")\n",
    "\n",
    "emissions, stride = generate_emissions(\n",
    "    alignment_model, audio_waveform, batch_size=batch_size\n",
    ")\n",
    "\n",
    "# The model is not needed once the emissions exist; drop it before continuing.\n",
    "del alignment_model\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "tokens_starred, text_starred = preprocess_text(\n",
    "    full_transcript,\n",
    "    romanize=True,\n",
    "    language=langs_to_iso[info.language],\n",
    ")\n",
    "\n",
    "segments, scores, blank_token = get_alignments(\n",
    "    emissions,\n",
    "    tokens_starred,\n",
    "    alignment_tokenizer,\n",
    ")\n",
    "\n",
    "spans = get_spans(tokens_starred, segments, blank_token)\n",
    "\n",
    "# Result consumed by the speaker-mapping cell further down.\n",
    "word_timestamps = postprocess_results(text_starred, spans, stride, scores)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "7EEaJPsQ21Rx"
   },
   "source": [
    "## Convert audio to mono for NeMo compatibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch folder under the working directory; the same path is later handed\n",
    "# to create_config and cleanup.\n",
    "ROOT = os.getcwd()\n",
    "temp_path = os.path.join(ROOT, \"temp_outputs\")\n",
    "os.makedirs(temp_path, exist_ok=True)\n",
    "\n",
    "# Move the waveform to the CPU as float32 and add a leading axis, which\n",
    "# channels_first=True treats as the channel axis (assumes audio_waveform is\n",
    "# a 1-D tensor of samples - TODO confirm).\n",
    "mono_waveform = audio_waveform.cpu().unsqueeze(0).float()\n",
    "torchaudio.save(\n",
    "    os.path.join(temp_path, \"mono_file.wav\"), mono_waveform, 16000, channels_first=True\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "D1gkViCf2-CV"
   },
   "source": [
    "## Speaker Diarization using NeMo MSDD Model\n",
    "---\n",
    "This code uses a model called Nvidia NeMo MSDD (Multi-scale Diarization Decoder) to perform speaker diarization on an audio signal. Speaker diarization is the process of separating an audio signal into different segments based on who is speaking at any given time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C7jIpBCH02RL"
   },
   "outputs": [],
   "source": [
    "# Initialize NeMo MSDD diarization model on the same `device` the alignment\n",
    "# model was loaded on, instead of a hard-coded \"cuda\" that fails on CPU-only\n",
    "# runtimes.\n",
    "msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(device)\n",
    "msdd_model.diarize()\n",
    "\n",
    "# The model object is not used after diarize(); drop it before continuing.\n",
    "del msdd_model\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "NmkZYaDAEOAg"
   },
   "source": [
    "## Mapping Speakers to Sentences According to Timestamps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E65LUGQe02zw"
   },
   "outputs": [],
   "source": [
    "# Reading timestamps <> Speaker Labels mapping\n",
    "\n",
    "# Each RTTM row is split on any run of whitespace, so the parse does not\n",
    "# depend on how many spaces pad the columns. After splitting, field 3 is the\n",
    "# segment onset in seconds, field 4 its duration in seconds and field 7 the\n",
    "# speaker label (e.g. speaker_0). Blank lines are skipped instead of raising\n",
    "# an IndexError.\n",
    "speaker_ts = []\n",
    "with open(os.path.join(temp_path, \"pred_rttms\", \"mono_file.rttm\"), \"r\") as f:\n",
    "    for line in f:\n",
    "        fields = line.split()\n",
    "        if not fields:\n",
    "            continue\n",
    "        # Times are stored as integer milliseconds: [start, end, speaker_id].\n",
    "        s = int(float(fields[3]) * 1000)\n",
    "        e = s + int(float(fields[4]) * 1000)\n",
    "        speaker_ts.append([s, e, int(fields[7].split(\"_\")[-1])])\n",
    "\n",
    "wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, \"start\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "8Ruxc8S1EXtW"
   },
   "source": [
    "## Realigning Speech segments using Punctuation\n",
    "---\n",
    "\n",
    "This code provides a method for disambiguating speaker labels in cases where a sentence is split between two different speakers. It uses punctuation markings to determine the dominant speaker for each sentence in the transcription.\n",
    "\n",
    "```\n",
    "Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are\n",
    "Speaker B: going to suck, right? So it's actually it hits you on both sides.\n",
    "```\n",
    "\n",
    "For example, if a sentence is split between two speakers, the code takes the mode of speaker labels for each word in the sentence, and uses that speaker label for the whole sentence. This can help to improve the accuracy of speaker diarization, especially in cases where the Whisper model may not take fine utterances like \"hmm\" and \"yeah\" into account, but the Diarization Model (NeMo) may include them, leading to inconsistent results.\n",
    "\n",
    "The code also handles cases where one speaker is giving a monologue while other speakers are making occasional comments in the background. It ignores the comments and assigns the entire monologue to the speaker who is speaking the majority of the time. This provides a robust and reliable method for realigning speech segments to their respective speakers based on punctuation in the transcription."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pgfC5hA41BXu"
   },
   "outputs": [],
   "source": [
    "if info.language not in punct_model_langs:\n",
    "    logging.warning(\n",
    "        f\"Punctuation restoration is not available for {info.language} language. Using the original punctuation.\"\n",
    "    )\n",
    "else:\n",
    "    # Predict punctuation for the word sequence; sentence-ending marks are\n",
    "    # then written back into the word/speaker mapping in place.\n",
    "    punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
    "\n",
    "    words_list = [entry[\"word\"] for entry in wsm]\n",
    "    labled_words = punct_model.predict(words_list, chunk_size=230)\n",
    "\n",
    "    ending_puncts = \".?!\"\n",
    "    model_puncts = \".,;:!?\"\n",
    "\n",
    "    # Acronyms such as U.S.A. already end in a period of their own, so they\n",
    "    # are allowed through the \"already punctuated\" check below.\n",
    "    def is_acronym(x):\n",
    "        return re.fullmatch(r\"\\b(?:[a-zA-Z]\\.){2,}\", x)\n",
    "\n",
    "    for word_dict, labeled_tuple in zip(wsm, labled_words):\n",
    "        word = word_dict[\"word\"]\n",
    "        if not word:\n",
    "            continue\n",
    "        # Only sentence-ending predictions are applied.\n",
    "        if labeled_tuple[1] not in ending_puncts:\n",
    "            continue\n",
    "        # Leave words that already carry punctuation, unless they are acronyms.\n",
    "        if word[-1] in model_puncts and not is_acronym(word):\n",
    "            continue\n",
    "\n",
    "        word += labeled_tuple[1]\n",
    "        # A doubled trailing period is collapsed by stripping trailing periods.\n",
    "        if word.endswith(\"..\"):\n",
    "            word = word.rstrip(\".\")\n",
    "        word_dict[\"word\"] = word\n",
    "\n",
    "wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
    "ssm = get_sentences_speaker_mapping(wsm, speaker_ts)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "vF2QAtLOFvwZ"
   },
   "source": [
    "## Cleanup and Exporting the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kFTyKI6B1MI0"
   },
   "outputs": [],
   "source": [
    "# Both exports are written next to the input audio and share its base name.\n",
    "output_stem = os.path.splitext(audio_path)[0]\n",
    "\n",
    "# Plain-text transcript.\n",
    "with open(f\"{output_stem}.txt\", \"w\", encoding=\"utf-8-sig\") as f:\n",
    "    get_speaker_aware_transcript(ssm, f)\n",
    "\n",
    "# SRT subtitle file.\n",
    "with open(f\"{output_stem}.srt\", \"w\", encoding=\"utf-8-sig\") as srt:\n",
    "    write_srt(ssm, srt)\n",
    "\n",
    "# Hand the scratch directory to the notebook's cleanup helper.\n",
    "cleanup(temp_path)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyOyiQNkD+ROzss634BOsrSh",
   "collapsed_sections": [
    "eCmjcOc9yEtQ",
    "jbsUt3SwyhjD"
   ],
   "include_colab_link": true,
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
